{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "720f2b62",
   "metadata": {},
   "source": [
    "# Stacking"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b365a02",
   "metadata": {},
   "source": [
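    "## Conclusion first: on this dataset (fetch_covtype), stacking outperforms blending and linear weighting\n",
    "In competitions we often use linear weighting as the final ensembling method, and it is natural to ask which weights work best; an example is given below.\n",
    "Reference: https://github.com/rushter/heamy/tree/master/examples"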
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc8fecb1",
   "metadata": {},
   "source": [
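    "Run 5-fold cross-validation on the training set: the out-of-fold predictions form the second-layer training set, and the base models' predictions on the test set form the second-layer test set.\n",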
    "<img src=\"assets/stacking.jpg\" width=\"50%\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "18a12000",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Collecting heamy\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/20/32/2f3e1efa38a8e34f790d90b6d49ef06ab812181ae896c50e89b8750fa5a0/heamy-0.0.7.tar.gz (30 kB)\n",
      "Requirement already satisfied: scikit-learn>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (0.24.1)\n",
      "Requirement already satisfied: pandas>=0.17.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.2.4)\n",
      "Requirement already satisfied: six>=1.10.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.15.0)\n",
      "Requirement already satisfied: scipy>=0.16.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.6.2)\n",
      "Requirement already satisfied: numpy>=1.7.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from heamy) (1.19.5)\n",
      "Requirement already satisfied: pytz>=2017.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2021.1)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in d:\\programdata\\anaconda3\\lib\\site-packages (from pandas>=0.17.0->heamy) (2.8.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (2.1.0)\n",
      "Requirement already satisfied: joblib>=0.11 in d:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn>=0.17.0->heamy) (1.0.1)\n",
      "Building wheels for collected packages: heamy\n",
      "  Building wheel for heamy (setup.py): started\n",
      "  Building wheel for heamy (setup.py): finished with status 'done'\n",
      "  Created wheel for heamy: filename=heamy-0.0.7-py2.py3-none-any.whl size=15353 sha256=e3ba65b34e2bdee3b90b45b637e28836afdbdb0c9547f76b36fe10d17f8aba8f\n",
      "  Stored in directory: c:\\users\\administrator\\appdata\\local\\pip\\cache\\wheels\\6e\\f1\\7d\\048e558da94f495a0ed0d9c09d312e73eb176a092e36774ec2\n",
      "Successfully built heamy\n",
      "Installing collected packages: heamy\n",
      "Successfully installed heamy-0.0.7\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install heamy  # install the required package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "69632c6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "print(sys.version)  # Python version info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ca421279",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "\n",
    "from heamy.dataset import Dataset\n",
    "from heamy.estimator import Classifier \n",
    "from heamy.pipeline import ModelsPipeline\n",
    "# Import the models; install any missing package with `pip install <package>`\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "import xgboost as xgb \n",
    "import lightgbm as lgb \n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.preprocessing import OrdinalEncoder\n",
    "from sklearn.metrics import log_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2592fbbd",
   "metadata": {},
   "source": [
    "## Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9a0fabe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_covtype\n",
    "data = fetch_covtype()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5bd75178",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
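      "7-class task, before processing: [1 2 3 4 5 6 7]\n",
      "[5 5 2 ... 3 3 3]\n",
      "7-class task, after processing: [0. 1. 2. 3. 4. 5. 6.]\n",
      "[4. 4. 1. ... 2. 2. 2.]\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing\n",
    "X, y = data['data'], data['target']\n",
    "# The models expect labels starting at 0, so map 1..7 to 0..6\n",
    "# (OrdinalEncoder returns floats, hence 0. to 6. in the output)\n",
    "print('7-class task, before processing:', np.unique(y))\n",
    "print(y)\n",
    "encoder = OrdinalEncoder()\n",
    "y = encoder.fit_transform(y.reshape(-1, 1))\n",
    "y = y.reshape(-1, )\n",
    "print('7-class task, after processing:', np.unique(y))\n",
    "print(y)"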
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "23d9778c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(435759, 54)\n",
      "(145253, 54)\n"
     ]
    }
   ],
   "source": [
    "# Split into training and test sets\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
    "print(X_train.shape)\n",
    "print(X_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "eac48668",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset(5c3ccfb5c81451d098565ef5e7e36ac5)\n"
     ]
    }
   ],
   "source": [
    "# Create the dataset\n",
    "'''use_cache : bool, default True\n",
    "    If use_cache=True then preprocessing step will be cached until function code is changed.'''\n",
    "dataset = Dataset(X_train=X_train, y_train=y_train, X_test=X_test, y_test=None, use_cache=True)  # note y_test=None: the test labels are never passed in, so they cannot leak\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fba3f975",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2.833e+03, 2.580e+02, 2.600e+01, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00],\n",
       "       [3.008e+03, 4.500e+01, 2.000e+00, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00],\n",
       "       [2.949e+03, 0.000e+00, 1.100e+01, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00],\n",
       "       ...,\n",
       "       [3.153e+03, 2.870e+02, 1.700e+01, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00],\n",
       "       [3.065e+03, 3.480e+02, 2.100e+01, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00],\n",
       "       [3.021e+03, 2.600e+01, 1.600e+01, ..., 0.000e+00, 0.000e+00,\n",
       "        0.000e+00]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The processed dataset\n",
    "dataset.X_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4517ea1",
   "metadata": {},
   "source": [
    "## Define the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e8393e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "def xgb_model(X_train, y_train, X_test, y_test):\n",
    "    \"\"\"The arguments must be X_train, y_train, X_test, y_test.\"\"\"\n",
    "    # Parameters can be hard-coded inside the function\n",
    "    params = {'objective': 'multi:softprob',\n",
    "              \"eval_metric\": 'mlogloss',\n",
    "              \"verbosity\": 0,\n",
    "              'num_class': 7,\n",
    "              'nthread': -1}\n",
    "    dtrain = xgb.DMatrix(X_train, y_train)\n",
    "    dtest = xgb.DMatrix(X_test)\n",
    "    model = xgb.train(params, dtrain, num_boost_round=300)\n",
    "    predict = model.predict(dtest)\n",
    "    return predict  # the return value must be the predictions for X_test\n",
    "\n",
    "\n",
    "def lgb_model(X_train, y_train, X_test, y_test, **parameters):\n",
    "    # Parameters can also be exposed through the function signature;\n",
    "    # **parameters is always a dict (possibly empty), so no None check is needed\n",
    "    lgb_train = lgb.Dataset(X_train, y_train)\n",
    "    model = lgb.train(params=parameters, train_set=lgb_train,num_boost_round=300)\n",
    "    predict = model.predict(X_test)\n",
    "    return predict\n",
    "\n",
    "\n",
    "def rf_model(X_train, y_train, X_test, y_test):\n",
    "    params = {\"n_estimators\": 100, \"n_jobs\": -1}\n",
    "    model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
    "    predict = model.predict_proba(X_test)\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0715cf6e",
   "metadata": {},
   "source": [
    "## Build and train the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "78ab0083",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\"objective\": \"multiclass\",\n",
    "          \"num_class\": 7,\n",
    "          \"n_jobs\": -1,\n",
    "          \"verbose\": -4, \n",
    "          \"metric\": (\"multi_logloss\",)}\n",
    "\n",
    "model_xgb = Classifier(dataset=dataset, estimator=xgb_model, name='xgb',use_cache=False)\n",
    "model_lgb = Classifier(dataset=dataset, estimator=lgb_model, name='lgb',parameters=params,use_cache=False)\n",
    "model_rf = Classifier(dataset=dataset, estimator=rf_model,name='rf',use_cache=False)\n",
    "\n",
    "pipeline = ModelsPipeline(model_xgb, model_lgb, model_rf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "173ef0f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Score (log_loss): 0.18646435443714865\n",
      "Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
      "Wall time: 14min 19s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([2.53464919e-01, 1.48562205e-20, 7.46535081e-01])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "pipeline.find_weights(scorer=log_loss)  # prints and returns the best weight combination"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "80726d19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 1h 39min 20s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build the stacked feature set from 5-fold training; this step is slow\n",
    "stack_ds = pipeline.stack(k=5,seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b25bba3c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>xgb_0</th>\n",
       "      <th>xgb_1</th>\n",
       "      <th>xgb_2</th>\n",
       "      <th>xgb_3</th>\n",
       "      <th>xgb_4</th>\n",
       "      <th>xgb_5</th>\n",
       "      <th>xgb_6</th>\n",
       "      <th>lgb_0</th>\n",
       "      <th>lgb_1</th>\n",
       "      <th>lgb_2</th>\n",
       "      <th>...</th>\n",
       "      <th>lgb_4</th>\n",
       "      <th>lgb_5</th>\n",
       "      <th>lgb_6</th>\n",
       "      <th>rf_0</th>\n",
       "      <th>rf_1</th>\n",
       "      <th>rf_2</th>\n",
       "      <th>rf_3</th>\n",
       "      <th>rf_4</th>\n",
       "      <th>rf_5</th>\n",
       "      <th>rf_6</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.177179</td>\n",
       "      <td>0.818728</td>\n",
       "      <td>2.185222e-07</td>\n",
       "      <td>9.264143e-09</td>\n",
       "      <td>4.090067e-03</td>\n",
       "      <td>1.725062e-06</td>\n",
       "      <td>1.048052e-06</td>\n",
       "      <td>0.179625</td>\n",
       "      <td>0.804684</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>...</td>\n",
       "      <td>1.562435e-02</td>\n",
       "      <td>6.370849e-05</td>\n",
       "      <td>1.004259e-08</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.005155</td>\n",
       "      <td>0.994845</td>\n",
       "      <td>7.055579e-10</td>\n",
       "      <td>1.326343e-08</td>\n",
       "      <td>6.331572e-09</td>\n",
       "      <td>1.435787e-09</td>\n",
       "      <td>1.603579e-10</td>\n",
       "      <td>0.008114</td>\n",
       "      <td>0.991886</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.87</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.293492</td>\n",
       "      <td>0.706508</td>\n",
       "      <td>3.650662e-10</td>\n",
       "      <td>1.017633e-09</td>\n",
       "      <td>8.823530e-09</td>\n",
       "      <td>6.384080e-10</td>\n",
       "      <td>2.823794e-08</td>\n",
       "      <td>0.831445</td>\n",
       "      <td>0.168555</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>4.999034e-07</td>\n",
       "      <td>4.015190e-09</td>\n",
       "      <td>4.997854e-09</td>\n",
       "      <td>0.63</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.478112</td>\n",
       "      <td>0.521816</td>\n",
       "      <td>3.207779e-06</td>\n",
       "      <td>2.878019e-08</td>\n",
       "      <td>1.076500e-08</td>\n",
       "      <td>2.230641e-06</td>\n",
       "      <td>6.630235e-05</td>\n",
       "      <td>0.465733</td>\n",
       "      <td>0.534184</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>8.245405e-05</td>\n",
       "      <td>0.55</td>\n",
       "      <td>0.45</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.992430</td>\n",
       "      <td>0.006652</td>\n",
       "      <td>1.233117e-05</td>\n",
       "      <td>1.887496e-07</td>\n",
       "      <td>1.569583e-06</td>\n",
       "      <td>5.604260e-07</td>\n",
       "      <td>9.037877e-04</td>\n",
       "      <td>0.932050</td>\n",
       "      <td>0.043451</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>2.449972e-02</td>\n",
       "      <td>0.97</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      xgb_0     xgb_1         xgb_2         xgb_3         xgb_4         xgb_5  \\\n",
       "0  0.177179  0.818728  2.185222e-07  9.264143e-09  4.090067e-03  1.725062e-06   \n",
       "1  0.005155  0.994845  7.055579e-10  1.326343e-08  6.331572e-09  1.435787e-09   \n",
       "2  0.293492  0.706508  3.650662e-10  1.017633e-09  8.823530e-09  6.384080e-10   \n",
       "3  0.478112  0.521816  3.207779e-06  2.878019e-08  1.076500e-08  2.230641e-06   \n",
       "4  0.992430  0.006652  1.233117e-05  1.887496e-07  1.569583e-06  5.604260e-07   \n",
       "\n",
       "          xgb_6     lgb_0     lgb_1     lgb_2  ...         lgb_4  \\\n",
       "0  1.048052e-06  0.179625  0.804684  0.000003  ...  1.562435e-02   \n",
       "1  1.603579e-10  0.008114  0.991886  0.000000  ...  0.000000e+00   \n",
       "2  2.823794e-08  0.831445  0.168555  0.000000  ...  4.999034e-07   \n",
       "3  6.630235e-05  0.465733  0.534184  0.000000  ...  0.000000e+00   \n",
       "4  9.037877e-04  0.932050  0.043451  0.000000  ...  0.000000e+00   \n",
       "\n",
       "          lgb_5         lgb_6  rf_0  rf_1  rf_2  rf_3  rf_4  rf_5  rf_6  \n",
       "0  6.370849e-05  1.004259e-08  0.03  0.96   0.0   0.0  0.01   0.0   0.0  \n",
       "1  0.000000e+00  0.000000e+00  0.13  0.87   0.0   0.0  0.00   0.0   0.0  \n",
       "2  4.015190e-09  4.997854e-09  0.63  0.37   0.0   0.0  0.00   0.0   0.0  \n",
       "3  0.000000e+00  8.245405e-05  0.55  0.45   0.0   0.0  0.00   0.0   0.0  \n",
       "4  0.000000e+00  2.449972e-02  0.97  0.03   0.0   0.0  0.00   0.0   0.0  \n",
       "\n",
       "[5 rows x 21 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training set produced by the base models: each model contributes 7 features, the predicted probabilities of the 7 classes\n",
    "stack_ds.X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "835205e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>xgb_0</th>\n",
       "      <th>xgb_1</th>\n",
       "      <th>xgb_2</th>\n",
       "      <th>xgb_3</th>\n",
       "      <th>xgb_4</th>\n",
       "      <th>xgb_5</th>\n",
       "      <th>xgb_6</th>\n",
       "      <th>lgb_0</th>\n",
       "      <th>lgb_1</th>\n",
       "      <th>lgb_2</th>\n",
       "      <th>...</th>\n",
       "      <th>lgb_4</th>\n",
       "      <th>lgb_5</th>\n",
       "      <th>lgb_6</th>\n",
       "      <th>rf_0</th>\n",
       "      <th>rf_1</th>\n",
       "      <th>rf_2</th>\n",
       "      <th>rf_3</th>\n",
       "      <th>rf_4</th>\n",
       "      <th>rf_5</th>\n",
       "      <th>rf_6</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.991493</td>\n",
       "      <td>0.000210</td>\n",
       "      <td>4.796058e-06</td>\n",
       "      <td>6.178684e-08</td>\n",
       "      <td>6.947614e-07</td>\n",
       "      <td>2.410490e-09</td>\n",
       "      <td>8.291459e-03</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.024731</td>\n",
       "      <td>0.964372</td>\n",
       "      <td>6.387765e-04</td>\n",
       "      <td>4.205048e-08</td>\n",
       "      <td>1.006575e-02</td>\n",
       "      <td>1.879628e-04</td>\n",
       "      <td>4.830114e-06</td>\n",
       "      <td>0.065073</td>\n",
       "      <td>0.877354</td>\n",
       "      <td>0.001678</td>\n",
       "      <td>...</td>\n",
       "      <td>5.530599e-02</td>\n",
       "      <td>0.000589</td>\n",
       "      <td>8.601484e-11</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.80</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.000780</td>\n",
       "      <td>0.979776</td>\n",
       "      <td>8.593459e-04</td>\n",
       "      <td>1.267791e-07</td>\n",
       "      <td>1.710379e-02</td>\n",
       "      <td>1.477527e-03</td>\n",
       "      <td>2.521008e-06</td>\n",
       "      <td>0.005164</td>\n",
       "      <td>0.933849</td>\n",
       "      <td>0.016355</td>\n",
       "      <td>...</td>\n",
       "      <td>3.657553e-02</td>\n",
       "      <td>0.008057</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.97</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.042695</td>\n",
       "      <td>0.957304</td>\n",
       "      <td>2.283268e-08</td>\n",
       "      <td>4.387427e-08</td>\n",
       "      <td>4.175481e-07</td>\n",
       "      <td>4.406019e-08</td>\n",
       "      <td>6.909629e-10</td>\n",
       "      <td>0.054392</td>\n",
       "      <td>0.945608</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>4.285638e-08</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.000457</td>\n",
       "      <td>0.999334</td>\n",
       "      <td>3.366338e-06</td>\n",
       "      <td>4.893879e-08</td>\n",
       "      <td>2.045808e-04</td>\n",
       "      <td>7.889498e-07</td>\n",
       "      <td>1.415576e-09</td>\n",
       "      <td>0.001367</td>\n",
       "      <td>0.995857</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>2.776106e-03</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      xgb_0     xgb_1         xgb_2         xgb_3         xgb_4         xgb_5  \\\n",
       "0  0.991493  0.000210  4.796058e-06  6.178684e-08  6.947614e-07  2.410490e-09   \n",
       "1  0.024731  0.964372  6.387765e-04  4.205048e-08  1.006575e-02  1.879628e-04   \n",
       "2  0.000780  0.979776  8.593459e-04  1.267791e-07  1.710379e-02  1.477527e-03   \n",
       "3  0.042695  0.957304  2.283268e-08  4.387427e-08  4.175481e-07  4.406019e-08   \n",
       "4  0.000457  0.999334  3.366338e-06  4.893879e-08  2.045808e-04  7.889498e-07   \n",
       "\n",
       "          xgb_6     lgb_0     lgb_1     lgb_2  ...         lgb_4     lgb_5  \\\n",
       "0  8.291459e-03  0.000000  0.000000  0.000000  ...  0.000000e+00  0.000000   \n",
       "1  4.830114e-06  0.065073  0.877354  0.001678  ...  5.530599e-02  0.000589   \n",
       "2  2.521008e-06  0.005164  0.933849  0.016355  ...  3.657553e-02  0.008057   \n",
       "3  6.909629e-10  0.054392  0.945608  0.000000  ...  4.285638e-08  0.000000   \n",
       "4  1.415576e-09  0.001367  0.995857  0.000000  ...  2.776106e-03  0.000000   \n",
       "\n",
       "          lgb_6  rf_0  rf_1  rf_2  rf_3  rf_4  rf_5  rf_6  \n",
       "0  1.000000e+00  0.99  0.00  0.00   0.0  0.00  0.00  0.01  \n",
       "1  8.601484e-11  0.09  0.80  0.05   0.0  0.06  0.00  0.00  \n",
       "2  0.000000e+00  0.01  0.97  0.00   0.0  0.01  0.01  0.00  \n",
       "3  0.000000e+00  0.04  0.96  0.00   0.0  0.00  0.00  0.00  \n",
       "4  0.000000e+00  0.00  0.99  0.00   0.0  0.01  0.00  0.00  \n",
       "\n",
       "[5 rows x 21 columns]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test set produced by the base models: each model contributes 7 features, the predicted probabilities of the 7 classes\n",
    "stack_ds.X_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b9db35dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 1min 53s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Use logistic regression as the final layer\n",
    "stacker = Classifier(dataset=stack_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
    "predict_stack = stacker.predict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a4a48219",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9.99137967e-01 4.76574513e-04 3.32764186e-07 ... 3.03237700e-06\n",
      "  1.62161211e-06 3.79326620e-04]\n",
      " [2.00732175e-02 9.71682830e-01 1.51484266e-03 ... 5.69138554e-03\n",
      "  9.40725428e-04 9.60822625e-05]\n",
      " [5.54556002e-03 9.91048437e-01 8.04840682e-04 ... 2.11437934e-03\n",
      "  4.56463787e-04 3.00919502e-05]\n",
      " ...\n",
      " [4.60179790e-06 1.78298095e-03 9.91553958e-01 ... 7.26752933e-04\n",
      "  3.79135124e-03 1.53584401e-05]\n",
      " [9.96307096e-01 2.43558944e-03 1.01596361e-07 ... 3.94985596e-05\n",
      "  7.41569805e-06 1.20819024e-03]\n",
      " [5.34671504e-05 7.62534718e-04 5.58323657e-03 ... 2.11410908e-04\n",
      "  9.91805379e-01 1.69502656e-05]]\n"
     ]
    }
   ],
   "source": [
    "print(predict_stack)  # predictions after stacking"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1372d4f8",
   "metadata": {},
   "source": [
    "## Evaluate the results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ef71d4",
   "metadata": {},
   "source": [
    "### Single-model scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a28806a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9254473229468583\n",
      "0.8412562907478676\n",
      "0.9535087055000585\n"
     ]
    }
   ],
   "source": [
    "print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, :7].values, axis=1)))  # XGB\n",
    "print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, 7:14].values, axis=1)))  # LGB\n",
    "print(accuracy_score(y_test, np.argmax(stack_ds.X_test.iloc[:, 14:].values, axis=1)))  # RF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "1db92fec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9537840870756542\n",
      "Wall time: 50.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Check that a standalone single-model run gives a consistent result\n",
    "rf_predict = rf_model(X_train, y_train, X_test, None)\n",
    "print(accuracy_score(y_test, np.argmax(rf_predict, axis=1)))  # RF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e9423ce",
   "metadata": {},
   "source": [
    "### Linear-weighting scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d2b50ba4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear weighting, hand-picked weights: 0.9504106627746071\n",
      "Linear weighting, best weights: 0.9530749795184953\n"
     ]
    }
   ],
   "source": [
    "# Linear-weighting scores\n",
    "xgb_t = stack_ds.X_test.iloc[:, :7].values\n",
    "lgb_t = stack_ds.X_test.iloc[:, 7:14].values\n",
    "rf_t = stack_ds.X_test.iloc[:, 14:].values\n",
    "\n",
    "# Weights picked by hand, roughly following each model's score\n",
    "result = 0.2*xgb_t + 0.1*lgb_t + 0.7*rf_t\n",
    "print('Linear weighting, hand-picked weights:', accuracy_score(y_test, np.argmax(result, axis=1)))\n",
    "# Using the best weights found above: Best Weights: [2.53464919e-01 1.48562205e-20 7.46535081e-01]\n",
    "result = 2.53464919e-01*xgb_t + 1.48562205e-20*lgb_t + 7.46535081e-01*rf_t\n",
    "print('Linear weighting, best weights:', accuracy_score(y_test, np.argmax(result, axis=1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8ce00b2",
   "metadata": {},
   "source": [
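    "The optimized weights beat our hand-picked ones (0.9531 vs. 0.9504), but the result is still slightly below the best single model (RF, 0.9535)."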
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4a12d6d",
   "metadata": {},
   "source": [
    "### Blending score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ffe1cf4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 14min 10s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "blend_ds = pipeline.blend(seed=111)\n",
    "blender = Classifier(dataset=blend_ds, estimator=LogisticRegression, parameters={\"solver\": 'lbfgs', \"max_iter\": 1000},use_cache=False)\n",
    "predict_blend = blender.predict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "506adffb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9546859617357301\n"
     ]
    }
   ],
   "source": [
    "print(accuracy_score(y_test, np.argmax(predict_blend, axis=1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e84c19ae",
   "metadata": {},
   "source": [
    "Blending improves the score slightly (0.9547, versus 0.9535 for the best single model)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8daf1e3",
   "metadata": {},
   "source": [
    "### Stacking score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "4930b407",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9589887988544127\n"
     ]
    }
   ],
   "source": [
    "print(accuracy_score(y_test, np.argmax(predict_stack, axis=1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5ad3c4b",
   "metadata": {},
   "source": [
    "Stacking gives a clear improvement: 0.9590, versus 0.9547 for blending and 0.9535 for the best single model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6311fbd7",
   "metadata": {},
   "source": [
    "## To restate the conclusion: on this dataset (fetch_covtype), stacking outperforms blending and linear weighting"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
